{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import xgboost as xgb\n",
    "import matplotlib.pylab as plt\n",
    "%matplotlib inline\n",
    "\n",
    "from sklearn.model_selection import GridSearchCV\n",
    "from sklearn.model_selection import train_test_split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.datasets.samples_generator import make_classification\n",
    "# X为样本特征，y为样本类别输出， 共10000个样本，每个样本20个特征，输出有2个类别，没有冗余特征，每个类别一个簇\n",
    "X, y = make_classification(n_samples=10000, n_features=20, n_redundant=0,\n",
    "                             n_clusters_per_class=1, n_classes=2, flip_y=0.1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(7500, 20)\n",
      "(7500,)\n",
      "(2500, 20)\n",
      "(2500,)\n"
     ]
    }
   ],
   "source": [
    "X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=1)\n",
    "print (X_train.shape)\n",
    "print (y_train.shape)\n",
    "print (X_test.shape)\n",
    "print (y_test.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h1>XGBoost 使用原生API</h1>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "dtrain = xgb.DMatrix(X_train,y_train)\n",
    "dtest = xgb.DMatrix(X_test,y_test)\n",
    "param = {'max_depth':5, 'eta':0.5, 'verbosity':1, 'objective':'binary:logistic'}\n",
    "raw_model = xgb.train(param, dtrain, num_boost_round=20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.9664\n"
     ]
    }
   ],
   "source": [
    "from sklearn.metrics import accuracy_score\n",
    "pred_train_raw = raw_model.predict(dtrain)\n",
    "for i in range(len(pred_train_raw)):\n",
    "    if pred_train_raw[i] > 0.5:\n",
    "         pred_train_raw[i]=1\n",
    "    else:\n",
    "        pred_train_raw[i]=0               \n",
    "print (accuracy_score(dtrain.get_label(), pred_train_raw))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.9408\n"
     ]
    }
   ],
   "source": [
    "pred_test_raw = raw_model.predict(dtest)\n",
    "for i in range(len(pred_test_raw)):\n",
    "    if pred_test_raw[i] > 0.5:\n",
    "         pred_test_raw[i]=1\n",
    "    else:\n",
    "        pred_test_raw[i]=0               \n",
    "print (accuracy_score(dtest.get_label(), pred_test_raw))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h1>XGBoost 使用sklearn wrapper，仍然使用原始API的参数</h1>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0]\tvalidation_0-error:0.0636\n",
      "Will train until validation_0-error hasn't improved in 10 rounds.\n",
      "[1]\tvalidation_0-error:0.062\n",
      "[2]\tvalidation_0-error:0.0624\n",
      "[3]\tvalidation_0-error:0.062\n",
      "[4]\tvalidation_0-error:0.062\n",
      "[5]\tvalidation_0-error:0.062\n",
      "[6]\tvalidation_0-error:0.062\n",
      "[7]\tvalidation_0-error:0.062\n",
      "[8]\tvalidation_0-error:0.062\n",
      "[9]\tvalidation_0-error:0.0608\n",
      "[10]\tvalidation_0-error:0.0608\n",
      "[11]\tvalidation_0-error:0.0608\n",
      "[12]\tvalidation_0-error:0.0608\n",
      "[13]\tvalidation_0-error:0.0604\n",
      "[14]\tvalidation_0-error:0.0604\n",
      "[15]\tvalidation_0-error:0.0604\n",
      "[16]\tvalidation_0-error:0.0604\n",
      "[17]\tvalidation_0-error:0.0604\n",
      "[18]\tvalidation_0-error:0.0608\n",
      "[19]\tvalidation_0-error:0.0608\n",
      "[20]\tvalidation_0-error:0.06\n",
      "[21]\tvalidation_0-error:0.06\n",
      "[22]\tvalidation_0-error:0.06\n",
      "[23]\tvalidation_0-error:0.0592\n",
      "[24]\tvalidation_0-error:0.0588\n",
      "[25]\tvalidation_0-error:0.0588\n",
      "[26]\tvalidation_0-error:0.0584\n",
      "[27]\tvalidation_0-error:0.0584\n",
      "[28]\tvalidation_0-error:0.0584\n",
      "[29]\tvalidation_0-error:0.0584\n",
      "[30]\tvalidation_0-error:0.0584\n",
      "[31]\tvalidation_0-error:0.0584\n",
      "[32]\tvalidation_0-error:0.0584\n",
      "[33]\tvalidation_0-error:0.0584\n",
      "[34]\tvalidation_0-error:0.0584\n",
      "[35]\tvalidation_0-error:0.0584\n",
      "[36]\tvalidation_0-error:0.0584\n",
      "Stopping. Best iteration:\n",
      "[26]\tvalidation_0-error:0.0584\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "XGBClassifier(base_score=0.5, booster='gbtree', colsample_bylevel=1,\n",
       "       colsample_bynode=1, colsample_bytree=1, eta=0.5, gamma=0,\n",
       "       learning_rate=0.1, max_delta_step=0, max_depth=5,\n",
       "       min_child_weight=1, missing=None, n_estimators=100, n_jobs=1,\n",
       "       nthread=None, objective='binary:logistic', random_state=0,\n",
       "       reg_alpha=0, reg_lambda=1, scale_pos_weight=1, seed=None,\n",
       "       silent=None, subsample=1, verbosity=1)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sklearn_model_raw = xgb.XGBClassifier(**param)\n",
    "sklearn_model_raw.fit(X_train, y_train, early_stopping_rounds=10, eval_metric=\"error\",\n",
    "        eval_set=[(X_test, y_test)])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h1>XGBoost 使用sklearn wrapper，使用sklearn风格的参数(推荐)</h1>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "sklearn_model_new = xgb.XGBClassifier(max_depth=5,learning_rate= 0.5, verbosity=1, objective='binary:logistic',random_state=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0]\tvalidation_0-error:0.0636\n",
      "Will train until validation_0-error hasn't improved in 10 rounds.\n",
      "[1]\tvalidation_0-error:0.0624\n",
      "[2]\tvalidation_0-error:0.0604\n",
      "[3]\tvalidation_0-error:0.0592\n",
      "[4]\tvalidation_0-error:0.0592\n",
      "[5]\tvalidation_0-error:0.0584\n",
      "[6]\tvalidation_0-error:0.058\n",
      "[7]\tvalidation_0-error:0.0588\n",
      "[8]\tvalidation_0-error:0.0588\n",
      "[9]\tvalidation_0-error:0.0588\n",
      "[10]\tvalidation_0-error:0.0588\n",
      "[11]\tvalidation_0-error:0.0588\n",
      "[12]\tvalidation_0-error:0.0588\n",
      "[13]\tvalidation_0-error:0.058\n",
      "[14]\tvalidation_0-error:0.0584\n",
      "[15]\tvalidation_0-error:0.0584\n",
      "[16]\tvalidation_0-error:0.0584\n",
      "Stopping. Best iteration:\n",
      "[6]\tvalidation_0-error:0.058\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "XGBClassifier(base_score=0.5, booster='gbtree', colsample_bylevel=1,\n",
       "       colsample_bynode=1, colsample_bytree=1, gamma=0, learning_rate=0.5,\n",
       "       max_delta_step=0, max_depth=5, min_child_weight=1, missing=None,\n",
       "       n_estimators=100, n_jobs=1, nthread=None,\n",
       "       objective='binary:logistic', random_state=1, reg_alpha=0,\n",
       "       reg_lambda=1, scale_pos_weight=1, seed=None, silent=None,\n",
       "       subsample=1, verbosity=1)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sklearn_model_new.fit(X_train, y_train, early_stopping_rounds=10, eval_metric=\"error\",\n",
    "        eval_set=[(X_test, y_test)])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h1>使用sklearn网格搜索调参</h1>\n",
    "\n",
    "一般固定步长，先调好框架参数n_estimators，再调弱学习器参数max_depth，min_child_weight,gamma等，接着调正则化相关参数subsample，colsample_byXXX, reg_alpha以及reg_lambda,最后固定前面调好的参数，来调步长learning_rate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
     ]
    },
    {
     "data": {
      "text/plain": [
       "GridSearchCV(cv=None, error_score='raise',\n",
       "       estimator=XGBClassifier(base_score=0.5, booster='gbtree', colsample_bylevel=1,\n",
       "       colsample_bynode=1, colsample_bytree=1, gamma=0, learning_rate=0.5,\n",
       "       max_delta_step=0, max_depth=5, min_child_weight=1, missing=None,\n",
       "       n_estimators=100, n_jobs=1, nthread=None,\n",
       "       objective='binary:logistic', random_state=1, reg_alpha=0,\n",
       "       reg_lambda=1, scale_pos_weight=1, seed=None, silent=None,\n",
       "       subsample=1, verbosity=1),\n",
       "       fit_params=None, iid=True, n_jobs=1,\n",
       "       param_grid={'max_depth': [4, 5, 6], 'n_estimators': [5, 10, 20]},\n",
       "       pre_dispatch='2*n_jobs', refit=True, return_train_score='warn',\n",
       "       scoring=None, verbose=0)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gsCv = GridSearchCV(sklearn_model_new,\n",
    "                   {'max_depth': [4,5,6],\n",
    "                    'n_estimators': [5,10,20]})\n",
    "gsCv.fit(X_train,y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.9533333333333334\n",
      "{'max_depth': 4, 'n_estimators': 10}\n"
     ]
    }
   ],
   "source": [
    "print(gsCv.best_score_)\n",
    "print(gsCv.best_params_)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
     ]
    },
    {
     "data": {
      "text/plain": [
       "GridSearchCV(cv=None, error_score='raise',\n",
       "       estimator=XGBClassifier(base_score=0.5, booster='gbtree', colsample_bylevel=1,\n",
       "       colsample_bynode=1, colsample_bytree=1, gamma=0, learning_rate=0.1,\n",
       "       max_delta_step=0, max_depth=4, min_child_weight=1, missing=None,\n",
       "       n_estimators=10, n_jobs=1, nthread=None,\n",
       "       objective='binary:logistic', random_state=1, reg_alpha=0,\n",
       "       reg_lambda=1, scale_pos_weight=1, seed=None, silent=None,\n",
       "       subsample=1, verbosity=1),\n",
       "       fit_params=None, iid=True, n_jobs=1,\n",
       "       param_grid={'learning_rate ': [0.3, 0.5, 0.7]},\n",
       "       pre_dispatch='2*n_jobs', refit=True, return_train_score='warn',\n",
       "       scoring=None, verbose=0)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sklearn_model_new2 = xgb.XGBClassifier(max_depth=4,n_estimators=10,verbosity=1, objective='binary:logistic',random_state=1)\n",
    "gsCv2 = GridSearchCV(sklearn_model_new2, \n",
    "                   {'learning_rate ': [0.3,0.5,0.7]})\n",
    "gsCv2.fit(X_train,y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.9516\n",
      "{'learning_rate ': 0.3}\n"
     ]
    }
   ],
   "source": [
    "print(gsCv2.best_score_)\n",
    "print(gsCv2.best_params_)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0]\tvalidation_0-error:0.062\n",
      "Will train until validation_0-error hasn't improved in 10 rounds.\n",
      "[1]\tvalidation_0-error:0.0592\n",
      "[2]\tvalidation_0-error:0.0608\n",
      "[3]\tvalidation_0-error:0.0608\n",
      "[4]\tvalidation_0-error:0.0608\n",
      "[5]\tvalidation_0-error:0.0604\n",
      "[6]\tvalidation_0-error:0.0592\n",
      "[7]\tvalidation_0-error:0.0588\n",
      "[8]\tvalidation_0-error:0.0588\n",
      "[9]\tvalidation_0-error:0.0588\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "XGBClassifier(base_score=0.5, booster='gbtree', colsample_bylevel=1,\n",
       "       colsample_bynode=1, colsample_bytree=1, gamma=0, learning_rate=0.3,\n",
       "       max_delta_step=0, max_depth=4, min_child_weight=1, missing=None,\n",
       "       n_estimators=10, n_jobs=1, nthread=None,\n",
       "       objective='binary:logistic', random_state=0, reg_alpha=0,\n",
       "       reg_lambda=1, scale_pos_weight=1, seed=None, silent=None,\n",
       "       subsample=1, verbosity=1)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sklearn_model_new2 = xgb.XGBClassifier(max_depth=4,learning_rate= 0.3, verbosity=1, objective='binary:logistic',n_estimators=10)\n",
    "sklearn_model_new2.fit(X_train, y_train, early_stopping_rounds=10, eval_metric=\"error\",\n",
    "        eval_set=[(X_test, y_test)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.9412\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [

     ]
    }
   ],
   "source": [
    "pred_test_new = sklearn_model_new2.predict(X_test)\n",
    "print (accuracy_score(dtest.get_label(), pred_test_new))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
